{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook demonstrates voice activity detection from a microphone's stream (online) and a given wav file (offline) in NeMo."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The notebook requires PyAudio library to get a signal from an audio device.\n",
    "For Ubuntu, please run the following commands to install it:\n",
    "```\n",
    "sudo apt-get install -y portaudio19-dev\n",
    "pip install pyaudio\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import nemo\n",
    "import nemo.collections.asr as nemo_asr\n",
    "import numpy as np\n",
    "import pyaudio as pa\n",
    "import time\n",
    "\n",
    "import librosa\n",
    "import IPython.display as ipd\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Architecture and Weights\n",
    "\n",
    "The model architecture is defined in a YAML file available in the config directory. MatchboxNet 3x1x64 has been trained on the [Google Speech Commands v2 dataset](https://arxiv.org/abs/1804.03209) and [freesound](https://freesound.org), and these weights are available on NGC. They will automatically be downloaded if not found."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_YAML = '../configs/quartznet_vad_3x1.yaml'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the checkpoint files\n",
    "\n",
    "base_checkpoint_path = './checkpoints/matchboxnet_3x1x1/'\n",
    "CHECKPOINT_ENCODER = os.path.join(base_checkpoint_path, 'JasperEncoder-STEP-90800.pt')\n",
    "CHECKPOINT_DECODER = os.path.join(base_checkpoint_path, 'JasperDecoderForClassification-STEP-90800.pt')\n",
    "\n",
    "if not os.path.exists(base_checkpoint_path):\n",
    "    os.makedirs(base_checkpoint_path)\n",
    "    \n",
    "if not os.path.exists(CHECKPOINT_ENCODER):\n",
    "    !wget https://api.ngc.nvidia.com/v2/models/nvidia/vad_matchboxnet_3x1x1/versions/1/files/JasperEncoder-STEP-90800.pt -P {base_checkpoint_path};\n",
    "if not os.path.exists(CHECKPOINT_DECODER):\n",
    "    !wget https://api.ngc.nvidia.com/v2/models/nvidia/vad_matchboxnet_3x1x1/versions/1/files/JasperDecoderForClassification-STEP-90800.pt -P {base_checkpoint_path};\n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Construct the Neural Modules and the eval graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ruamel.yaml import YAML\n",
    "yaml = YAML(typ=\"safe\")\n",
    "with open(MODEL_YAML) as f:\n",
    "    model_definition = yaml.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "neural_factory = nemo.core.NeuralModuleFactory(\n",
    "    placement=nemo.core.DeviceType.GPU)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define a Neural Module to iterate over audio\n",
    "\n",
    "Here we define a custom Neural Module which acts as an iterator over a stream of audio that is supplied to it. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.backends.pytorch.nm import DataLayerNM\n",
    "from nemo.core.neural_types import NeuralType, AudioSignal, LengthsType\n",
    "import torch\n",
    "\n",
    "# Minimal data layer that feeds one in-memory audio buffer into the graph\n",
    "class AudioDataLayer(DataLayerNM):\n",
    "    '''Iterator-style data layer that yields the buffer given to set_signal() exactly once.'''\n",
    "\n",
    "    def __init__(self, sample_rate):\n",
    "        super().__init__()\n",
    "        self._sample_rate = sample_rate\n",
    "        # Flag read by __next__: True means one batch is still pending\n",
    "        self.output = True\n",
    "\n",
    "    @property\n",
    "    def output_ports(self):\n",
    "        '''Neural types of the two outputs: the audio batch and its length.'''\n",
    "        ports = {}\n",
    "        ports['audio_signal'] = NeuralType(('B', 'T'), AudioSignal(freq=self._sample_rate))\n",
    "        ports['a_sig_length'] = NeuralType(('B',), LengthsType())\n",
    "        return ports\n",
    "\n",
    "    def set_signal(self, signal):\n",
    "        '''Store signal / 32768 as a [1, T] float32 array and mark one batch as pending.'''\n",
    "        scaled = signal.astype(np.float32) / 32768.\n",
    "        self.signal = scaled.reshape([1, -1])\n",
    "        self.signal_shape = np.expand_dims(self.signal.size, 0).astype(np.int64)\n",
    "        self.output = True\n",
    "\n",
    "    def __iter__(self):\n",
    "        return self\n",
    "\n",
    "    def __next__(self):\n",
    "        if self.output:\n",
    "            # Hand out the stored signal once, then stop until set_signal() is called again\n",
    "            self.output = False\n",
    "            audio = torch.as_tensor(self.signal, dtype=torch.float32)\n",
    "            length = torch.as_tensor(self.signal_shape, dtype=torch.int64)\n",
    "            return audio, length\n",
    "        raise StopIteration\n",
    "\n",
    "    def __len__(self):\n",
    "        return 1\n",
    "\n",
    "    @property\n",
    "    def dataset(self):\n",
    "        return None\n",
    "\n",
    "    @property\n",
    "    def data_iterator(self):\n",
    "        return self"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Instantiate the Neural Modules\n",
    "\n",
    "We now instantiate the neural modules (data layer, preprocessor, encoder and decoder), load the downloaded pretrained weights into the encoder and decoder, and construct the DAG to evaluate MatchboxNet on audio streams."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate necessary neural modules\n",
    "data_layer = AudioDataLayer(sample_rate=model_definition['sample_rate'])\n",
    "\n",
    "data_preprocessor = nemo_asr.AudioToMFCCPreprocessor(\n",
    "    **model_definition['AudioToMFCCPreprocessor'])\n",
    "\n",
    "jasper_encoder = nemo_asr.JasperEncoder(\n",
    "    **model_definition['JasperEncoder'])\n",
    "\n",
    "jasper_decoder = nemo_asr.JasperDecoderForClassification(\n",
    "    feat_in=model_definition['JasperEncoder']['jasper'][-1]['filters'],\n",
    "    num_classes=len(model_definition['labels']))\n",
    "\n",
    "# load pre-trained model\n",
    "jasper_encoder.restore_from(CHECKPOINT_ENCODER)\n",
    "jasper_decoder.restore_from(CHECKPOINT_DECODER)\n",
    "\n",
    "# Define inference DAG\n",
    "audio_signal, audio_signal_len = data_layer()\n",
    "processed_signal, processed_signal_len = data_preprocessor(\n",
    "    input_signal=audio_signal,\n",
    "    length=audio_signal_len)\n",
    "encoded, encoded_len = jasper_encoder(audio_signal=processed_signal,\n",
    "                                      length=processed_signal_len)\n",
    "log_probs = jasper_decoder(encoder_output=encoded)\n",
    "\n",
    "# inference method for audio signal (single instance)\n",
    "def infer_signal(self, signal):\n",
    "    '''Feed signal to the data layer, evaluate log_probs and return element [0][0] of the result.'''\n",
    "    data_layer.set_signal(signal)\n",
    "    evaluated = self.infer([log_probs], verbose=False)\n",
    "    log_probs_batches = evaluated[0]\n",
    "    return log_probs_batches[0]\n",
    "\n",
    "neural_factory.infer_signal = infer_signal.__get__(neural_factory)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# FrameASR: Helper class for streaming inference\n",
    "Here we adopt FrameASR for streaming inference for voice activity detection."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# class for streaming frame-based ASR\n",
    "# 1) use reset() method to reset FrameASR's state\n",
    "# 2) call transcribe(frame) to do ASR on\n",
    "#    contiguous signal's frames\n",
    "class FrameASR:\n",
    "    \n",
    "    def __init__(self, neural_factory, model_definition,\n",
    "                 frame_len=2, frame_overlap=2.5, \n",
    "                 offset=10):\n",
    "        '''\n",
    "        Args:\n",
    "          neural_factory: object exposing infer_signal(signal); used for every inference call\n",
    "          model_definition: parsed model config; 'labels' and 'sample_rate' are read from it\n",
    "          frame_len: frame's duration, seconds\n",
    "          frame_overlap: duration of overlaps before and after current frame, seconds\n",
    "          offset: number of trailing entries dropped from each decoded result\n",
    "        '''\n",
    "        # Keep the factory that was passed in instead of silently relying on a notebook global\n",
    "        self.neural_factory = neural_factory\n",
    "        self.vocab = list(model_definition['labels'])\n",
    "        self.vocab.append('_')\n",
    "        \n",
    "        self.sr = model_definition['sample_rate']\n",
    "        self.frame_len = frame_len\n",
    "        self.n_frame_len = int(frame_len * self.sr)\n",
    "        self.frame_overlap = frame_overlap\n",
    "        self.n_frame_overlap = int(frame_overlap * self.sr)\n",
    "        # Sliding window in samples: overlap + current frame + overlap\n",
    "        self.buffer = np.zeros(shape=2*self.n_frame_overlap + self.n_frame_len,\n",
    "                               dtype=np.float32)\n",
    "        self.offset = offset\n",
    "        self.reset()\n",
    "        \n",
    "    def _decode(self, frame, offset=0):\n",
    "        '''\n",
    "        Push one frame of exactly n_frame_len samples into the buffer, run inference\n",
    "        on the whole buffer and return the decoded result minus its last `offset` entries.\n",
    "        '''\n",
    "        assert len(frame)==self.n_frame_len\n",
    "        # Shift the window left by one frame, then write the new frame at the end\n",
    "        self.buffer[:-self.n_frame_len] = self.buffer[self.n_frame_len:]\n",
    "        self.buffer[-self.n_frame_len:] = frame\n",
    "        logits = self.neural_factory.infer_signal(self.buffer).to('cpu').numpy()[0]\n",
    "        decoded = self._greedy_decoder(logits, self.vocab)\n",
    "        return decoded[:len(decoded)-offset]\n",
    "    \n",
    "    def transcribe(self, frame=None):\n",
    "        '''\n",
    "        Run inference for the next frame. A missing frame is replaced by silence and\n",
    "        a short frame is zero-padded on the right up to n_frame_len samples.\n",
    "        '''\n",
    "        if frame is None:\n",
    "            frame = np.zeros(shape=self.n_frame_len, dtype=np.float32)\n",
    "        if len(frame) < self.n_frame_len:\n",
    "            frame = np.pad(frame, [0, self.n_frame_len - len(frame)], 'constant')\n",
    "        unmerged = self._decode(frame, self.offset)\n",
    "        return unmerged\n",
    "    \n",
    "    def reset(self):\n",
    "        '''\n",
    "        Reset frame_history and decoder's state\n",
    "        '''\n",
    "        self.buffer=np.zeros(shape=self.buffer.shape, dtype=np.float32)\n",
    "        self.prev_char = ''\n",
    "\n",
    "    @staticmethod\n",
    "    def _greedy_decoder(logits, vocab):\n",
    "        '''\n",
    "        Returns [class index, class label, probs[0], probs[1], str(logits)] for the\n",
    "        most probable class, or an empty list when logits is empty.\n",
    "        '''\n",
    "        s = []\n",
    "        if logits.shape[0]:\n",
    "            probs = torch.softmax(torch.as_tensor(logits), dim=-1)\n",
    "            _, best = torch.max(probs, dim=-1)\n",
    "            # Convert to a plain int once so the list lookup does not depend on tensor indexing\n",
    "            pred = best.item()\n",
    "            s = [pred, str(vocab[pred]), probs[0].item(), probs[1].item(), str(logits)]\n",
    "        return s"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## What classes can this model recognize?\n",
    "\n",
    "Before we begin inference on the actual audio stream, let's look at the classes this model was trained to recognize."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = model_definition['labels']\n",
    "print(labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Listening to an audio stream and performing inference using FrameASR"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Offline Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can experiment with different **STEP** and **WINDOW_SIZE** values for streaming VAD inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "STEP_LIST =        [0.01, 0.01, 0.01]\n",
    "WINDOW_SIZE_LIST = [0.25, 0.20, 0.15]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import wave\n",
    "\n",
    "def offline_inference(wave_file, STEP = 0.025, WINDOW_SIZE = 0.5):\n",
    "    '''\n",
    "    Play wave_file through a PyAudio output stream in STEP-second chunks and run\n",
    "    FrameASR on every chunk.\n",
    "\n",
    "    Args:\n",
    "      wave_file: path of the wav file to read\n",
    "      STEP: chunk duration in seconds; one inference is run per chunk\n",
    "      WINDOW_SIZE: duration in seconds of the buffer given to the model\n",
    "                   (the chunk plus an overlap on each side)\n",
    "\n",
    "    Returns:\n",
    "      (preds, proba_b, proba_s): for each chunk with a non-empty result, entries\n",
    "      0, 2 and 3 of the list returned by FrameASR.transcribe()\n",
    "    '''\n",
    "    FRAME_LEN = STEP # infer every STEP seconds \n",
    "    CHANNELS = 1 # number of audio channels (expect mono signal)\n",
    "    RATE = 16000 # sample rate, Hz\n",
    "   \n",
    "    CHUNK_SIZE = int(FRAME_LEN*RATE)\n",
    "    asr = FrameASR(neural_factory, model_definition,\n",
    "                   frame_len=FRAME_LEN, frame_overlap = (WINDOW_SIZE-FRAME_LEN)/2,\n",
    "                   offset=0)\n",
    "\n",
    "    empty_counter = 0\n",
    "\n",
    "    preds = []\n",
    "    proba_b = []\n",
    "    proba_s = []\n",
    "\n",
    "    wf = wave.open(wave_file, 'rb')\n",
    "    p = pa.PyAudio()\n",
    "    stream = None\n",
    "    \n",
    "    def callback(in_data, frame_count, time_info, status):\n",
    "        # empty_counter belongs to offline_inference, so it is nonlocal here, not global\n",
    "        nonlocal empty_counter\n",
    "        data = wf.readframes(frame_count)\n",
    "        signal = np.frombuffer(data, dtype=np.int16)\n",
    "        result = asr.transcribe(signal)\n",
    "\n",
    "        if len(result):\n",
    "            # Index into result only after checking that it is not empty\n",
    "            preds.append(result[0])\n",
    "            proba_b.append(result[2])\n",
    "            proba_s.append(result[3])\n",
    "            print(result,end='\\n')\n",
    "            empty_counter = 3\n",
    "        elif empty_counter > 0:\n",
    "            empty_counter -= 1\n",
    "            if empty_counter == 0:\n",
    "                print(' ',end='')\n",
    "\n",
    "        return (data, pa.paContinue)\n",
    "\n",
    "    try:\n",
    "        stream = p.open(format=p.get_format_from_width(wf.getsampwidth()),\n",
    "                        channels=CHANNELS,\n",
    "                        rate=RATE,\n",
    "                        output = True,\n",
    "                        stream_callback=callback,\n",
    "                        frames_per_buffer=CHUNK_SIZE) # Specifies the number of frames per buffer.\n",
    "\n",
    "        stream.start_stream()\n",
    "\n",
    "        while stream.is_active():\n",
    "            time.sleep(0.1)\n",
    "    finally:\n",
    "        # Release the audio device and the wav file even if opening or streaming fails\n",
    "        if stream is not None:\n",
    "            stream.stop_stream()\n",
    "            stream.close()\n",
    "        p.terminate()\n",
    "        wf.close()\n",
    "\n",
    "    asr.reset()\n",
    "    return preds, proba_b, proba_s"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Here we show an example of offline streaming inference\n",
    "You can use your file or download the provided toy dataset. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "toy_data = './vad'\n",
    "if not os.path.exists(toy_data):\n",
    "    !wget -c \"https://github.com/NVIDIA/NeMo/blob/master/tests/data/vad.tar.xz?raw=true\" -O vad.tar.xz \n",
    "    !tar -xvf vad.tar.xz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wave_file = './vad/welcome_noisy.wav'\n",
    "CHANNELS = 1\n",
    "RATE = 16000\n",
    "audio, sample_rate = librosa.load(wave_file, sr=RATE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "results = []\n",
    "for STEP, WINDOW_SIZE in zip(STEP_LIST, WINDOW_SIZE_LIST):\n",
    "    print(f'====== STEP is {STEP}s, WINDOW_SIZE is {WINDOW_SIZE}s ====== ')\n",
    "    preds, proba_b, proba_s = offline_inference(wave_file, STEP, WINDOW_SIZE)\n",
    "    results.append([STEP, WINDOW_SIZE, preds, proba_b, proba_s])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from pylab import *\n",
    "import numpy as np\n",
    "import librosa.display\n",
    "plt.rcParams.update({'font.size': 10, 'font.family': 'sans-serif'})\n",
    "fig = plt.figure(figsize=[16,10])\n",
    "# Figure-level title: plt.title() would create an extra axes underneath the subplots\n",
    "fig.suptitle('Audio, Predictions and Probas')\n",
    "plt.subplots_adjust(hspace=2.00)\n",
    "\n",
    "\n",
    "FRAME_LEN = STEP_LIST[0]\n",
    "len_pred = len(results[0][2]) \n",
    "\n",
    "num = len(results)\n",
    "n_rows = num + 2  # one row per run, then the spectrogram and the waveform\n",
    "\n",
    "# One panel per (STEP, WINDOW_SIZE) run stored in results\n",
    "for row, (run_step, run_window, run_preds, run_proba_b, run_proba_s) in enumerate(results, start=1):\n",
    "    ax = plt.subplot(n_rows, 1, row)\n",
    "    ax.plot(run_preds, 'r', label='pred')\n",
    "    ax.plot(run_proba_b, 'g--', label='prob for background')\n",
    "    ax.plot(run_proba_s, 'b--', label='prob for speech')\n",
    "    ax.set_xlim([0, len_pred])\n",
    "    ax.set_title(f'step {run_step}s, buffer size {run_window}s')\n",
    "    ax.set_ylabel('Preds and Probas')\n",
    "    ax.set_xlabel('Segments')\n",
    "    ax.grid()\n",
    "    ax.legend(loc='lower left', shadow=True)\n",
    "\n",
    "# Mel spectrogram of the input audio\n",
    "ax = plt.subplot(n_rows, 1, num + 1)\n",
    "S = librosa.feature.melspectrogram(y=audio, sr=sample_rate, n_mels=128,\n",
    "                                   fmax=8000)\n",
    "S_dB = librosa.power_to_db(S, ref=np.max)\n",
    "librosa.display.specshow(S_dB, x_axis='time', y_axis='mel', \n",
    "                         sr=sample_rate, fmax=8000)\n",
    "ax.set_title('Mel-frequency spectrogram')\n",
    "ax.grid()\n",
    "\n",
    "# Raw waveform, with the x-axis limited to the time span covered by the predictions\n",
    "ax = plt.subplot(n_rows, 1, num + 2)\n",
    "ax.plot(np.arange(audio.size) / sample_rate, audio, 'b')\n",
    "ax.set_xlim([-0.01,  len_pred * FRAME_LEN])\n",
    "ax.set_ylabel('Signal')\n",
    "ax.set_xlabel('Time, seconds')\n",
    "ax.set_title(f'File: {str(wave_file)}')\n",
    "ax.set_ylim([-0.5,  0.5])\n",
    "ax.grid()\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import librosa\n",
    "ipd.Audio(audio, rate=sample_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Online inference through microphone"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "STEP = 0.01 \n",
    "WINDOW_SIZE = 0.20\n",
    "CHANNELS = 1 \n",
    "RATE = 16000\n",
    "\n",
    "CHUNK_SIZE = int(STEP * RATE)\n",
    "# The overlap is derived from this cell's own STEP; FRAME_LEN is only defined by the\n",
    "# plotting cell, so using it here would make this cell depend on hidden kernel state\n",
    "asr = FrameASR(neural_factory, model_definition,\n",
    "               frame_len=STEP, frame_overlap=(WINDOW_SIZE - STEP) / 2, \n",
    "               offset=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p = pa.PyAudio()\n",
    "\n",
    "# List every device that reports at least one input channel\n",
    "print('Available audio input devices:')\n",
    "for device_id in range(p.get_device_count()):\n",
    "    device_info = p.get_device_info_by_index(device_id)\n",
    "    if not device_info.get('maxInputChannels'):\n",
    "        continue\n",
    "    print(device_id, device_info.get('name'))\n",
    "print('Please type input device ID:')\n",
    "dev_idx = int(input())\n",
    "\n",
    "empty_counter = 0\n",
    "\n",
    "def callback(in_data, frame_count, time_info, status):\n",
    "    '''PyAudio input callback: run FrameASR on the captured chunk and print the result.'''\n",
    "    global empty_counter\n",
    "    samples = np.frombuffer(in_data, dtype=np.int16)\n",
    "    text = asr.transcribe(samples)\n",
    "    if len(text) > 0:\n",
    "        print(text,end='\\n')\n",
    "        empty_counter = 3\n",
    "        return (in_data, pa.paContinue)\n",
    "    if empty_counter > 0:\n",
    "        empty_counter -= 1\n",
    "        if not empty_counter:\n",
    "            print(' ',end='')\n",
    "    return (in_data, pa.paContinue)\n",
    "\n",
    "stream = p.open(\n",
    "    format=pa.paInt16,\n",
    "    channels=CHANNELS,\n",
    "    rate=RATE,\n",
    "    input=True,\n",
    "    input_device_index=dev_idx,\n",
    "    frames_per_buffer=CHUNK_SIZE,\n",
    "    stream_callback=callback,\n",
    ")\n",
    "\n",
    "print('Listening...')\n",
    "\n",
    "stream.start_stream()\n",
    "\n",
    "# Keep the cell alive while the callback processes microphone chunks\n",
    "while stream.is_active():\n",
    "    time.sleep(0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stream.stop_stream()\n",
    "stream.close()\n",
    "p.terminate()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
